{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to set `batch_transform` and `output_transform`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MONAI provides many ignite-style event handlers that supply general-purpose logic during training or evaluation, such as stats logging, MLFlow tracking, TensorBoard, and metrics.\n",
    "\n",
    "Most handlers accept a callable argument named `batch_transform` or `output_transform` that extracts the data the handler expects from ignite's `engine.state.batch` or `engine.state.output`. This tutorial shows how to set `batch_transform` and `output_transform` for different kinds of handlers and different data shapes.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, tqdm, ignite]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "from glob import glob\n",
    "import shutil\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "from ignite.metrics import Accuracy\n",
    "\n",
    "from monai.data import CacheDataset, create_test_image_3d, DataLoader\n",
    "from monai.engines import SupervisedEvaluator, SupervisedTrainer\n",
    "from monai.handlers import (\n",
    "    CheckpointSaver,\n",
    "    MeanDice,\n",
    "    StatsHandler,\n",
    "    TensorBoardImageHandler,\n",
    "    TensorBoardStatsHandler,\n",
    "    ValidationHandler,\n",
    "    from_engine,\n",
    ")\n",
    "from monai.inferers import SimpleInferer, SlidingWindowInferer\n",
    "from monai.losses import DiceLoss\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import (\n",
    "    Activationsd,\n",
    "    AsChannelFirstd,\n",
    "    AsDiscreted,\n",
    "    Compose,\n",
    "    KeepLargestConnectedComponentd,\n",
    "    LoadImaged,\n",
    "    RandCropByPosNegLabeld,\n",
    "    ScaleIntensityd,\n",
    "    EnsureTyped,\n",
    ")\n",
    "from monai.utils import get_torch_version_tuple\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data shape in `engine.state.batch` and `engine.state.output`\n",
    "\n",
    "All MONAI engines and handlers build on **PyTorch ignite** concepts, which define the engine workflow with a `State` object for data context and an `Event-handler` mechanism to attach and trigger components. For more details about `engine.state` and batch / output transforms, see https://pytorch.org/ignite/concepts.html#state.\n",
    "\n",
    "First, let's look at the possible data shapes in `engine.state.batch` and `engine.state.output`.\n",
    "\n",
    "### engine.state.batch\n",
    "(1) In a typical ignite program, `batch` is the item yielded by the PyTorch DataLoader, for example `{\"image\": Tensor, \"label\": Tensor, \"image_meta_dict\": Dict}`, where `image` and `label` are batch-first arrays and `image_meta_dict` is a dictionary of meta information for the input images. Every entry is batched:\n",
    "```\n",
    "image.shape = [2, 4, 64, 64, 64]  # here 2 is batch size, 4 is channels\n",
    "label.shape = [2, 3, 64, 64, 64]\n",
    "image_meta_dict = {\"filename_or_obj\": [\"/data/image1.nii\", \"/data/image2.nii\"]}\n",
    "```\n",
    "\n",
    "(2) MONAI engines automatically `decollate` the batch data into a list of `channel-first` items after every iteration. For more details about `decollate`, see https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb.\n",
    "\n",
    "The `engine.state.batch` example in (1) will be decollated into a list of dictionaries:\n",
    "`[{\"image\": Tensor, \"label\": Tensor, \"image_meta_dict\": Dict}, {\"image\": Tensor, \"label\": Tensor, \"image_meta_dict\": Dict}]`.\n",
    "\n",
    "Each item of the list can be:\n",
    "```\n",
    "image.shape = [4, 64, 64, 64]  # channel-first array for 1 image, 4 is channels\n",
    "label.shape = [3, 64, 64, 64]\n",
    "image_meta_dict = {\"filename_or_obj\": \"/data/image1.nii\"}\n",
    "```\n",
    "\n",
    "### engine.state.output\n",
    "(1) In a typical ignite program, `output` is the output data of the current iteration, for example `{\"pred\": Tensor, \"label\": Tensor, \"loss\": scalar}`, where `pred` and `label` are batch-first arrays and `loss` is the scalar loss value of the current iteration:\n",
    "```\n",
    "pred.shape = [2, 3, 64, 64, 64]  # here 2 is batch size, 3 is channels\n",
    "label.shape = [2, 3, 64, 64, 64]\n",
    "loss = 0.4534\n",
    "```\n",
    "\n",
    "(2) MONAI engines also automatically `decollate` the output data into a list of `channel-first` items after every iteration.\n",
    "The `engine.state.output` example in (1) will be decollated into a list of dictionaries:\n",
    "`[{\"pred\": Tensor, \"label\": Tensor, \"loss\": 0.4534}, {\"pred\": Tensor, \"label\": Tensor, \"loss\": 0.4534}]`. Note that the scalar `loss` value is replicated to every item of the decollated list."
   ]
  },
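  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decollation described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not MONAI's implementation: the helper `decollate_sketch` is a hypothetical name, nested lists stand in for tensors, and the batch size is passed in explicitly. It only shows how batched fields are split per item and how a scalar such as `loss` is replicated to every item."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decollate_sketch(batch, batch_size):\n",
    "    \"\"\"Split a dict of batched fields into a list of per-item dicts (illustration only).\"\"\"\n",
    "    items = []\n",
    "    for idx in range(batch_size):\n",
    "        item = {}\n",
    "        for key, value in batch.items():\n",
    "            if isinstance(value, list) and len(value) == batch_size:\n",
    "                # batched field: take the entry that belongs to this item\n",
    "                item[key] = value[idx]\n",
    "            else:\n",
    "                # scalar field (for example `loss`): replicate it to every item\n",
    "                item[key] = value\n",
    "        items.append(item)\n",
    "    return items\n",
    "\n",
    "\n",
    "# plain lists stand in for batch-first tensors with batch size 2\n",
    "output = {\"pred\": [[0.9, 0.1], [0.2, 0.8]], \"label\": [[1, 0], [0, 1]], \"loss\": 0.4534}\n",
    "decollated = decollate_sketch(output, batch_size=2)\n",
    "print(decollated[0])\n",
    "print(decollated[1])"
   ]
  },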
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define `batch_transform` and `output_transform` to extract data\n",
    "\n",
    "Now let's look at how to extract data from `engine.state.batch` or `engine.state.output`. To simplify this, MONAI provides the utility function `monai.handlers.from_engine`, which handles the common cases automatically.\n",
    "\n",
    "(1) To get the meta data from a dictionary-format `engine.state.batch`, set `batch_transform=lambda x: x[\"image_meta_dict\"]`.\n",
    "\n",
    "(2) To get the meta data from a decollated list of dictionaries in `engine.state.batch`, set `batch_transform=lambda x: [i[\"image_meta_dict\"] for i in x]` or `batch_transform=from_engine(\"image_meta_dict\")`.\n",
    "\n",
    "(3) Metrics usually expect a `(pred, label)` tuple as input. If `engine.state.output` is a dictionary, set `output_transform=lambda x: (x[\"pred\"], x[\"label\"])`. If it is a decollated list, set `output_transform=lambda x: ([i[\"pred\"] for i in x], [i[\"label\"] for i in x])` or `output_transform=from_engine([\"pred\", \"label\"])`.\n",
    "\n",
    "(4) To get a scalar value such as `loss` when `engine.state.output` is a dictionary, set `output_transform=lambda x: x[\"loss\"]`. If it is a decollated list, set `output_transform=lambda x: x[0][\"loss\"]` or `output_transform=from_engine([\"loss\"], first=True)`. Here `first=True` is required so that `from_engine` returns the scalar value from the first item only, because the decollated list holds a replicated copy of it in every item."
   ]
  },
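  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the four cases concrete, the lambdas listed above can be applied to toy data by hand. This sketch assumes plain Python lists and strings as stand-ins for tensors and meta data, and it does not call `from_engine`; all variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dictionary-format data, as produced by a common ignite program\n",
    "batch_dict = {\n",
    "    \"image\": [[1.0], [2.0]],\n",
    "    \"image_meta_dict\": {\"filename_or_obj\": [\"/data/image1.nii\", \"/data/image2.nii\"]},\n",
    "}\n",
    "output_dict = {\"pred\": [[0.9], [0.2]], \"label\": [[1], [0]], \"loss\": 0.4534}\n",
    "\n",
    "# decollated data, as produced by MONAI engines: one dictionary per item\n",
    "batch_list = [\n",
    "    {\"image\": [1.0], \"image_meta_dict\": {\"filename_or_obj\": \"/data/image1.nii\"}},\n",
    "    {\"image\": [2.0], \"image_meta_dict\": {\"filename_or_obj\": \"/data/image2.nii\"}},\n",
    "]\n",
    "output_list = [\n",
    "    {\"pred\": [0.9], \"label\": [1], \"loss\": 0.4534},\n",
    "    {\"pred\": [0.2], \"label\": [0], \"loss\": 0.4534},\n",
    "]\n",
    "\n",
    "# cases (1) and (2): meta data from a dictionary and from a decollated list\n",
    "meta_from_dict = (lambda x: x[\"image_meta_dict\"])(batch_dict)\n",
    "meta_from_list = (lambda x: [i[\"image_meta_dict\"] for i in x])(batch_list)\n",
    "\n",
    "# case (3): the (pred, label) tuple that metrics expect\n",
    "pair_from_dict = (lambda x: (x[\"pred\"], x[\"label\"]))(output_dict)\n",
    "pair_from_list = (lambda x: ([i[\"pred\"] for i in x], [i[\"label\"] for i in x]))(output_list)\n",
    "\n",
    "# case (4): the scalar loss, taken from the first item of the decollated list\n",
    "loss_from_dict = (lambda x: x[\"loss\"])(output_dict)\n",
    "loss_from_list = (lambda x: x[0][\"loss\"])(output_list)\n",
    "\n",
    "print(meta_from_list)\n",
    "print(pair_from_list)\n",
    "print(loss_from_dict, loss_from_list)"
   ]
  },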
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "Now let's set up a real-world demo program with MONAI engines and show how to set `batch_transform` and `output_transform` in handlers.\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root dir is: /workspace/data/medical\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(f\"root dir is: {root_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare synthetic data for test\n",
    "\n",
    "Here we generate 40 (image, label) pairs, 20 for training, 20 for validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(40):\n",
    "    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"img{i:d}.nii.gz\"))\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"seg{i:d}.nii.gz\"))\n",
    "\n",
    "images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n",
    "segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n",
    "train_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[:20], segs[:20])]\n",
    "val_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[-20:], segs[-20:])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup train / val transforms, dataset, DataLoader, post transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"image\", \"label\"]),\n",
    "        AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n",
    "        ScaleIntensityd(keys=\"image\"),\n",
    "        RandCropByPosNegLabeld(\n",
    "            keys=[\"image\", \"label\"], label_key=\"label\", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4\n",
    "        ),\n",
    "        EnsureTyped(keys=[\"image\", \"label\"]),\n",
    "    ]\n",
    ")\n",
    "val_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"image\", \"label\"]),\n",
    "        AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n",
    "        ScaleIntensityd(keys=\"image\"),\n",
    "        EnsureTyped(keys=[\"image\", \"label\"]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# create a training data loader\n",
    "train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)\n",
    "# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training\n",
    "train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)\n",
    "# create a validation data loader\n",
    "val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)\n",
    "val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)\n",
    "\n",
    "val_post_transforms = Compose(\n",
    "    [\n",
    "        EnsureTyped(keys=\"pred\"),\n",
    "        Activationsd(keys=\"pred\", sigmoid=True),\n",
    "        AsDiscreted(keys=\"pred\", threshold_values=True),\n",
    "        KeepLargestConnectedComponentd(keys=\"pred\", applied_labels=[1]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup network, optimizer, loss function, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(16, 32, 64, 128, 256),\n",
    "    strides=(2, 2, 2, 2),\n",
    "    num_res_units=2,\n",
    ").to(device)\n",
    "loss = DiceLoss(sigmoid=True)\n",
    "opt = torch.optim.Adam(net.parameters(), 3e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup `handlers` and `metrics` for train and validation\n",
    "\n",
    "This is the main topic of this tutorial: set the `batch_transform` and `output_transform` arguments so that each handler and metric receives the data it expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "val_handlers = [\n",
    "    # returning None disables per-iteration stats during validation\n",
    "    StatsHandler(output_transform=lambda x: None),\n",
    "    TensorBoardStatsHandler(log_dir=\"./runs/\", output_transform=lambda x: None),\n",
    "    TensorBoardImageHandler(\n",
    "        log_dir=\"./runs/\",\n",
    "        batch_transform=from_engine([\"image\", \"label\"]),\n",
    "        output_transform=from_engine([\"pred\"]),\n",
    "    ),\n",
    "    CheckpointSaver(save_dir=\"./runs/\", save_dict={\"net\": net}, save_key_metric=True),\n",
    "]\n",
    "\n",
    "evaluator = SupervisedEvaluator(\n",
    "    device=device,\n",
    "    val_data_loader=val_loader,\n",
    "    network=net,\n",
    "    inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),\n",
    "    postprocessing=val_post_transforms,\n",
    "    key_val_metric={\n",
    "        \"val_mean_dice\": MeanDice(include_background=True, output_transform=from_engine([\"pred\", \"label\"]))\n",
    "    },\n",
    "    additional_metrics={\"val_acc\": Accuracy(output_transform=from_engine([\"pred\", \"label\"]))},\n",
    "    val_handlers=val_handlers,\n",
    "    # enable AMP evaluation only when the PyTorch version is >= 1.6\n",
    "    amp=get_torch_version_tuple() >= (1, 6),\n",
    ")\n",
    "\n",
    "train_handlers = [\n",
    "    ValidationHandler(validator=evaluator, interval=2, epoch_level=True),\n",
    "    StatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
    "    TensorBoardStatsHandler(\n",
    "        log_dir=\"./runs/\",\n",
    "        tag_name=\"train_loss\",\n",
    "        output_transform=from_engine([\"loss\"], first=True),\n",
    "    ),\n",
    "]\n",
    "\n",
    "trainer = SupervisedTrainer(\n",
    "    device=device,\n",
    "    max_epochs=5,\n",
    "    train_data_loader=train_loader,\n",
    "    network=net,\n",
    "    optimizer=opt,\n",
    "    loss_function=loss,\n",
    "    inferer=SimpleInferer(),\n",
    "    train_handlers=train_handlers,\n",
    "    # enable AMP training only when the PyTorch version is >= 1.6\n",
    "    amp=get_torch_version_tuple() >= (1, 6),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute training to verify the settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
